# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

import math
import sys

import mxnet as mx
import mxnet.ndarray as nd
import numpy as np
import pytest

from gluonts.mx.context import check_gpu_support
from gluonts.mx.kernels import RBFKernel
from gluonts.mx.linalg_util import jitter_cholesky, jitter_cholesky_eig
from gluonts.mx.model.gp_forecaster.gaussian_process import GaussianProcess


# This test verifies that both the eigenvalue decomposition and the iterative
# jitter method turn a non-positive-definite matrix into a positive-definite
# one, so that its Cholesky factor can be computed (a reference sketch of the
# iterative idea follows the test). Both CPU and GPU as well as single and
# double precision are tested.
@pytest.mark.skipif(
    sys.platform == "linux",
    reason="skipped since potrf crashes on MXNet 1.6.0 on Linux when the matrix is not SPD",
)
@pytest.mark.parametrize("ctx", ["cpu", "gpu"])
@pytest.mark.parametrize("jitter_method", ["iter", "eig"])
@pytest.mark.parametrize("float_type", [np.float32, np.float64])
def test_jitter_unit(jitter_method, float_type, ctx) -> None:
    # TODO: Enable GPU tests on Jenkins
    if ctx == "gpu" and not check_gpu_support():
        return

    with mx.Context(ctx):
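        # Neither 2x2 matrix in this batch is symmetric positive definite,
        # so a plain Cholesky factorization (potrf) would fail without jitter.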
        matrix = nd.array(
            [[[1, 2], [3, 4]], [[10, 100], [-21.5, 41]]], dtype=float_type
        )
        F = mx.nd
        num_data_points = matrix.shape[1]
        if jitter_method == "eig":
            L = jitter_cholesky_eig(F, matrix, num_data_points, float_type)
        elif jitter_method == "iter":
            L = jitter_cholesky(F, matrix, num_data_points, float_type)
        else:
            raise ValueError(f"Unknown jitter method: {jitter_method}")

    assert np.sum(np.isnan(L.asnumpy())) == 0, "NaNs in Cholesky factor!"
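

# For reference, a minimal NumPy sketch of the iterative jitter idea exercised
# above (an illustration for a single 2-D matrix, not the batched
# gluonts.mx.linalg_util implementation): add increasingly large multiples of
# the identity to the diagonal until the Cholesky factorization succeeds.
def _reference_jitter_cholesky(
    matrix: np.ndarray, max_iter: int = 10
) -> np.ndarray:
    # Start with a jitter proportional to the average diagonal magnitude.
    jitter = 1e-6 * float(np.mean(np.abs(np.diag(matrix))))
    identity = np.eye(matrix.shape[0], dtype=matrix.dtype)
    for _ in range(max_iter):
        try:
            return np.linalg.cholesky(matrix + jitter * identity)
        except np.linalg.LinAlgError:
            # Not positive definite yet: increase the jitter and retry.
            jitter *= 10.0
    raise np.linalg.LinAlgError("matrix could not be made positive definite")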


# This test verifies that the noiseless sample generated from the synthetic
# example does not contain NaNs. Without jitter, NaNs occur on the GPU for
# both single and double precision, and on the CPU for single precision only.
# The test checks that applying either jitter method fixes these numerical
# issues on both CPU and GPU and for single and double precision (a reference
# sketch of the underlying exact-inference equations follows the test).
@pytest.mark.skipif(
    sys.platform == "linux",
    reason="skipped since potrf crashes on MXNet 1.6.0 on Linux when the matrix is not SPD",
)
@pytest.mark.parametrize("ctx", ["cpu", "gpu"])
@pytest.mark.parametrize("jitter_method", ["iter", "eig"])
@pytest.mark.parametrize("float_type", [np.float32, np.float64])
def test_jitter_synthetic_gp(jitter_method, float_type, ctx) -> None:
    # TODO: Enable GPU tests on Jenkins
    if ctx == "gpu" and not check_gpu_support():
        return
    # Initialize problem parameters
    batch_size = 1
    prediction_length = 50
    context_length = 5
    num_samples = 3

    with mx.Context(ctx):
        # Initialize test data to generate Gaussian Process from
        lb = -5
        ub = 5
        dx = (ub - lb) / (prediction_length - 1)
        x_test = nd.arange(lb, ub + dx, dx, dtype=float_type).reshape(-1, 1)
        x_test = nd.tile(x_test, reps=(batch_size, 1, 1))
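        # x_test has shape (batch_size, prediction_length, 1).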

        # Define the GP hyperparameters
        amplitude = nd.ones((batch_size, 1, 1), dtype=float_type)
        length_scale = math.sqrt(0.4) * nd.ones_like(amplitude)
        sigma = math.sqrt(1e-5) * nd.ones_like(amplitude)
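        # sigma is used as the observation noise standard deviation here,
        # so the noise variance is 1e-5.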

        # Instantiate the desired kernel object
        rbf_kernel = RBFKernel(amplitude, length_scale)

        # Construct a zero-mean Gaussian process with the RBF kernel to sample from
        gp = GaussianProcess(
            sigma=sigma,
            kernel=rbf_kernel,
            prediction_length=prediction_length,
            context_length=context_length,
            num_samples=num_samples,
            float_type=float_type,
            jitter_method=jitter_method,
            sample_noise=False,  # Returns sample without noise
        )

        # Generate the training set on a subset of the interval using the sine function
        x_train = nd.array([-4, -3, -2, -1, 1], dtype=float_type).reshape(
            context_length, 1
        )
        x_train = nd.tile(x_train, reps=(batch_size, 1, 1))
        y_train = nd.sin(x_train.squeeze(axis=2))
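        # x_train has shape (batch_size, context_length, 1) and y_train has
        # shape (batch_size, context_length).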

        # Compute the exact GP predictive mean and covariance, and draw
        # samples, using the same fixed hyperparameters
        samples, predictive_mean, predictive_std = gp.exact_inference(
            x_train, y_train, x_test
        )

    assert (
        np.sum(np.isnan(samples.asnumpy())) == 0
    ), "NaNs in predictive samples!"

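
# For reference, the exact GP inference exercised above follows the standard
# zero-mean Gaussian process posterior equations. A minimal NumPy sketch for a
# single batch (an illustration only, not the GaussianProcess implementation;
# the exact parameterization of gluonts.mx.kernels.RBFKernel may differ), with
# the RBF kernel written in its textbook form
# k(x, x') = a^2 * exp(-(x - x')^2 / (2 * l^2)):
def _reference_exact_inference(
    x_train: np.ndarray,
    y_train: np.ndarray,
    x_test: np.ndarray,
    amplitude: float,
    length_scale: float,
    sigma: float,
):
    def rbf(a: np.ndarray, b: np.ndarray) -> np.ndarray:
        sq_dist = (a.reshape(-1, 1) - b.reshape(1, -1)) ** 2
        return amplitude**2 * np.exp(-sq_dist / (2.0 * length_scale**2))

    # Gram matrix on the training inputs with observation noise on the
    # diagonal, and cross-covariance between test and training inputs.
    k_xx = rbf(x_train, x_train) + sigma**2 * np.eye(len(x_train))
    k_sx = rbf(x_test, x_train)
    # Predictive mean: K_*x (K_xx + sigma^2 I)^{-1} y
    mean = k_sx @ np.linalg.solve(k_xx, y_train)
    # Predictive covariance: K_** - K_*x (K_xx + sigma^2 I)^{-1} K_x*
    cov = rbf(x_test, x_test) - k_sx @ np.linalg.solve(k_xx, k_sx.T)
    return mean, cov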